import torch.nn as nn

# Initialize model weights
def init_weights(model):
    """Initialize the Conv2d and Linear parameters of *model* in place.

    Every ``nn.Conv2d`` and ``nn.Linear`` module yielded by
    ``model.modules()`` (which includes ``model`` itself) has its weight
    filled via ``nn.init.kaiming_uniform_`` with default arguments and its
    bias, when present, set to zero. All other module types are untouched.

    Args:
        model: the ``nn.Module`` whose submodules are initialized in place.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_uniform_(m.weight)
            # Layers constructed with bias=False expose ``bias`` as None;
            # zeroing it unconditionally would raise for such layers.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
